{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "662360be-0285-4885-a477-2a563b6cd6ed",
   "metadata": {},
   "source": [
    "### 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68790045-f219-4831-8aeb-04c316ee7f7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "# 读取.parquet文件\n",
    "parquet_file = '/path/file_name.parquet'\n",
    "df = pd.read_parquet(parquet_file)\n",
    "\n",
    "# 获取text列的前1万条数据，只用10000条来做测试\n",
    "text_col = df['text'][:10000]\n",
    "\n",
    "# 指定要写入的txt文件\n",
    "txt_file = '/path/file_name.txt'\n",
    "\n",
    "# 将数据追加写入txt文件\n",
    "with open(txt_file, 'a') as file:\n",
    "    content_col.to_csv(file, sep='\\t', index=False, header=False)\n",
    "print(f'前1万条content数据已写入到 {txt_file}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f68c7254-dc01-4835-83e3-fec476587a28",
   "metadata": {},
   "source": [
    "### 开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbdb3dc6-c0bb-4820-8b02-661e95a05f88",
   "metadata": {},
   "outputs": [],
   "source": [
    "pip install sentencepiece\n",
    "\n",
    "nohup spm_train --input '/path/file_name.txt' \\\n",
    "--input_format text \\\n",
    "--model_prefix bpe_test \\\n",
    "--model_type bpe \\\n",
    "--vocab_size 10000 \\\n",
    "--character_coverage 0.9995 \\\n",
    "--num_threads 32 \\\n",
    "--split_digits True \\\n",
    "--byte_fallback True \\\n",
    "--max_sentence_length 24000 > bpe_test.log &"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efda410b-6836-4638-a70d-d77f188be2f3",
   "metadata": {},
   "source": [
    "### 开始使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f64b771-f659-4c10-8472-d991bb6d6423",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "sp_bpe = spm.SentencePieceProcessor() \n",
    "sp_bpe.load('bpe_test.model')\n",
    "print('*** BPE ***')\n",
    "print(sp_bpe.encode_as_pieces('The excellence of a translation can only be judged by noting'))\n",
    "print(len(sp_bpe.encode_as_pieces('The excellence of a translation can only be judged by noting')))\n",
    "print(sp_bpe.encode_as_pieces('麒麟，是中国古代神话中的一种瑞兽'))\n",
    "print(len(sp_bpe.encode_as_pieces('麒麟，是中国古代神话中的一种瑞兽')))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e47c652-b5eb-4973-b19a-3a5f35708110",
   "metadata": {},
   "source": [
    "### 合并LLaMa词表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c222b17b-8e8f-4823-86d6-b0e8a13f3e58",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION\"]=\"python\"\n",
    "from transformers import LlamaTokenizer\n",
    "from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model\n",
    "import sentencepiece as spm\n",
    "\n",
    "# 位置\n",
    "llama_tokenizer_dir = \"/path/llama-2-7b-hf\" # 换成你自己模型的位置\n",
    "chinese_sp_model_file =\"/path/bpe_test.model\" # 刚才训练的模型\n",
    "\n",
    "# 加载\n",
    "llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)\n",
    "chinese_sp_model = spm.SentencePieceProcessor()\n",
    "chinese_sp_model.Load(chinese_sp_model_file)\n",
    "llama_spm = sp_pb2_model.ModelProto()\n",
    "llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())\n",
    "chinese_spm = sp_pb2_model.ModelProto()\n",
    "chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())\n",
    "\n",
    "\n",
    "# 打印两个词表的大小和原llama的特殊token\n",
    "print(len(llama_tokenizer),len(chinese_sp_model))\n",
    "print(llama_tokenizer.all_special_tokens)\n",
    "print(llama_tokenizer.all_special_ids)\n",
    "print(llama_tokenizer.special_tokens_map)\n",
    "\n",
    "# 结果\n",
    "32000 10000\n",
    "['<s>', '</s>', '<unk>']\n",
    "[1, 2, 0]\n",
    "{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}\n",
    "\n",
    "# 开始往llama词表里添加，这里你也可以直接加入你想要加入词表的词，或者是领域内的特殊词\n",
    "llama_spm_tokens_set=set(p.piece for p in llama_spm.pieces)\n",
    "print(len(llama_spm_tokens_set))\n",
    "print(f\"Before:{len(llama_spm_tokens_set)}\")\n",
    "for p in chinese_spm.pieces:\n",
    "    piece = p.piece\n",
    "    if piece not in llama_spm_tokens_set:\n",
    "        new_p = sp_pb2_model.ModelProto().SentencePiece()\n",
    "        new_p.piece = piece\n",
    "        new_p.score = 0\n",
    "        llama_spm.pieces.append(new_p)\n",
    "print(f\"New model pieces: {len(llama_spm.pieces)}\")\n",
    "\n",
    "# 结果\n",
    "32000\n",
    "Before:32000\n",
    "New model pieces: 40114\n",
    "# 我们中文词表原来有1万，去重添加后，添加了8114个词。\n",
    "\n",
    "# 保存合并后的模型\n",
    "output_sp_dir = 'merged_tokenizer_sp_test'\n",
    "output_hf_dir = 'merged_tokenizer_hf_test'\n",
    "os.makedirs(output_sp_dir,exist_ok=True)\n",
    "with open(output_sp_dir+'/chinese_llama.model', 'wb') as f:\n",
    "    f.write(llama_spm.SerializeToString())\n",
    "tokenizer = LlamaTokenizer(vocab_file=output_sp_dir+'/chinese_llama.model')\n",
    "\n",
    "tokenizer.save_pretrained(output_hf_dir)\n",
    "print(f\"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}\")\n",
    "\n",
    "# 看一下效果\n",
    "llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)\n",
    "chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)\n",
    "\n",
    "\n",
    "text = \"The excellence of a translation can only be judged by noting\"\n",
    "print(\"Test text:\\n\",text)\n",
    "print(f\"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}\")\n",
    "print(f\"Tokenized length by LLaMA tokenizer:{len(llama_tokenizer.tokenize(text))}\")\n",
    "print(f\"Tokenized by chinese_llama tokenizer:{chinese_llama_tokenizer.tokenize(text)}\")\n",
    "print(f\"Tokenized length by LLaMA-extent-1 tokenizer:{len(chinese_llama_tokenizer.tokenize(text))}\")\n",
    "\n",
    "\n",
    "text = \"麒麟，是中国古代神话中的一种瑞兽\"\n",
    "print(\"Test text:\\n\",text)\n",
    "print(f\"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}\")\n",
    "print(f\"Tokenized length by LLaMA tokenizer:{len(llama_tokenizer.tokenize(text))}\")\n",
    "print(f\"Tokenized by chinese_llama tokenizer:{chinese_llama_tokenizer.tokenize(text)}\")\n",
    "print(f\"Tokenized length by chinese_llama tokenizer:{len(chinese_llama_tokenizer.tokenize(text))}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
